
Add --sliding_window flag to CoreML static LLM export #19250

Open

john-rocky wants to merge 1 commit into pytorch:main from john-rocky:coreml/sliding-window-static-llm

Conversation

@john-rocky

Summary

Models trained with sliding-window attention — Mistral 7B, Gemma 3, Gemma 4,
Llama 4 Scout, etc. — only need each layer to attend to the last W tokens.
export_static_llm_coreml.py was always sizing the per-layer KV cache to
max_context_len - input_len, so longer contexts were proportionally more
expensive in both KV cache memory and per-token attention compute even
though the model was trained to ignore everything outside the window.

Add a --sliding_window flag that caps the cache at the trained window.
The downstream pieces — StaticAttentionMask invariants under cache
eviction (validated by test_sliding_window_cache_and_mask) and
StaticAttentionIOManager's per-layer cache_lens plumbing — already
support this; the export script just needed to expose it.
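
As a usage sketch (only `--sliding_window` is added by this PR; the other flag names here are illustrative and may not match the script's actual arguments):

```
$ python export_static_llm_coreml.py --checkpoint model.pth --max_context_len 8192 --sliding_window 4096
```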

The cache_len computation is factored into _resolve_cache_len so it is
unit-testable. Per-layer mixed sliding/full attention (Gemma 3 / Gemma 4
alternate sliding and full layers) is intentionally left for a follow-up;
this PR uses one window for every layer. Documented in the ANE
Optimizations section of readme.md.
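
A minimal sketch of the helper's intended behavior, assuming a signature of `(max_context_len, input_len, sliding_window)`; the actual implementation in `export_static_llm_coreml.py` is authoritative:

```python
from typing import Optional


def _resolve_cache_len(
    max_context_len: int, input_len: int, sliding_window: Optional[int]
) -> int:
    """Sketch: cap the per-layer KV cache length at the trained sliding window."""
    if sliding_window is not None and sliding_window <= 0:
        raise ValueError("--sliding_window must be a positive integer")
    default_cache_len = max_context_len - input_len  # existing behavior
    if sliding_window is None:
        return default_cache_len
    # A window at least as large as the remaining context is a no-op,
    # so the model's training window can be passed verbatim.
    return min(sliding_window, default_cache_len)
```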

Memory savings example

For a 32-layer / n_kv_heads=8 / head_dim=128 model exported with
max_context_len=8192 in fp16:

| Setting | Per-layer KV cache (k + v) | Total KV cache (32 layers) |
| --- | --- | --- |
| Default (cache_len = 8160) | 8160 × 8 × 128 × 2 × 2 B ≈ 33 MB | ~1.07 GB |
| --sliding_window 4096 | 4096 × 8 × 128 × 2 × 2 B ≈ 16 MB | ~0.54 GB |
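
The table entries follow from cache_len × n_kv_heads × head_dim × 2 (k and v) × 2 bytes (fp16) per layer; a quick sanity check:

```python
def kv_cache_bytes(cache_len, n_kv_heads=8, head_dim=128, n_layers=32, dtype_bytes=2):
    per_layer = cache_len * n_kv_heads * head_dim * 2 * dtype_bytes  # k + v, fp16
    return per_layer, per_layer * n_layers

for cache_len in (8160, 4096):
    per_layer, total = kv_cache_bytes(cache_len)
    print(f"cache_len={cache_len}: ~{per_layer / 1e6:.1f} MB per layer, ~{total / 1e9:.2f} GB total")
# cache_len=8160: ~33.4 MB per layer, ~1.07 GB total
# cache_len=4096: ~16.8 MB per layer, ~0.54 GB total
```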

Test plan

Added unit tests in examples/apple/coreml/llama/test.py:

  • test_resolve_cache_len_no_sliding_window — default path is unchanged.
  • test_resolve_cache_len_with_sliding_window — cache shrinks to the window.
  • test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op — a
    user-provided window larger than the remaining context degenerates to the
    no-window case (so users can pass the model's training window verbatim).
  • test_resolve_cache_len_rejects_non_positive_window — input validation.
  • test_create_example_inputs_with_sliding_window_shrinks_kv_cache — full
    path: every cache tensor in the example inputs has its sequence dimension
    equal to the sliding window, and the attention mask covers
    input_len + sliding_window.
```
$ python -m pytest examples/apple/coreml/llama/test.py -v
test.py::test_split_model PASSED
test.py::test_resolve_cache_len_no_sliding_window PASSED
test.py::test_resolve_cache_len_with_sliding_window PASSED
test.py::test_resolve_cache_len_sliding_window_larger_than_context_is_a_no_op PASSED
test.py::test_resolve_cache_len_rejects_non_positive_window PASSED
test.py::test_create_example_inputs_with_sliding_window_shrinks_kv_cache PASSED
============================== 6 passed in 5.27s ===============================
```
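
For reference, the sliding-window resolution tests are essentially of the following shape (a sketch assuming the hypothetical `_resolve_cache_len(max_context_len, input_len, sliding_window)` signature above; the real assertions in examples/apple/coreml/llama/test.py are the source of truth):

```python
import pytest

def test_resolve_cache_len_with_sliding_window_sketch():
    # Cache shrinks from max_context_len - input_len down to the window.
    assert _resolve_cache_len(max_context_len=8192, input_len=32, sliding_window=4096) == 4096

def test_resolve_cache_len_window_larger_than_context_sketch():
    # A window larger than the remaining context degenerates to the default.
    assert _resolve_cache_len(8192, 32, sliding_window=100_000) == 8192 - 32

def test_resolve_cache_len_rejects_non_positive_window_sketch():
    with pytest.raises(ValueError):
        _resolve_cache_len(8192, 32, sliding_window=0)
```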

I also confirmed that the existing `examples/models/llama/tests/test_static_attention.py::test_sliding_window_cache_and_mask` already covers the cache-and-mask invariants under both `shift_pointer` and `smart_mask` eviction styles when `cache_len < total_tokens`, so this PR does not re-test that.

Authored with Claude.

john-rocky requested a review from metascroy as a code owner May 1, 2026 05:30

pytorch-bot Bot commented May 1, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19250

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 11 Awaiting Approval

As of commit 8a2dfb5 with merge base 94d2881:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label May 1, 2026

github-actions Bot commented May 1, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
